package com.lw.leetcode.linked.c;


import java.util.*;

/**
 * c
 * tree
 * 834. 树中距离之和 (Sum of Distances in Tree)
 *
 * @author liw
 * @version 1.0
 * @date 2021/7/19 17:50
 */
public class SumOfDistancesInTree {

    public static void main(String[] args) {
        SumOfDistancesInTree test = new SumOfDistancesInTree();

        // Further sample cases (n, edges -> expected):
        //   6, {{0, 1}, {0, 2}, {2, 3}, {2, 4}, {2, 5}} -> [8, 12, 6, 10, 10, 10]
        //   3, {{2, 0}, {1, 0}}                         -> [2, 3, 3]
        //   3, {{2, 1}, {0, 2}}                         -> [3, 3, 2]

        // Expected: [6, 6, 4, 4]
        int[][] arr = {{1, 2}, {3, 2}, {3, 0}};
        int n = 4;

        int[] ints = test.sumOfDistancesInTree(n, arr);
        System.out.println(Arrays.toString(ints));
    }

    /**
     * Computes, for every node {@code i} of an undirected tree, the sum of the distances
     * (in edges) from {@code i} to every other node.
     *
     * <p>Uses the "re-rooting" technique in O(n) time and memory. One pass roots the tree at
     * node 0 and computes, for every node, its subtree size and the sum of distances to the
     * nodes inside its subtree. A second pass derives each child's answer from its parent's:
     * {@code answer[child] = answer[parent] + n - 2 * size[child]}. Moving the root across the
     * edge brings the child's {@code size[child]} subtree nodes one step closer and pushes the
     * other {@code n - size[child]} nodes one step farther.
     *
     * <p>Both passes are iterative, so deep (e.g. path-shaped) trees cannot overflow the call
     * stack. The caller's {@code edges} array is only read, never modified, and no state is
     * kept on the instance between calls.
     *
     * @param n     number of nodes, labelled {@code 0 .. n-1}
     * @param edges the undirected edges of the tree, each given as {@code {a, b}}; the input is
     *              assumed to form a single tree (connected, no cycles)
     * @return an array whose {@code i}-th entry is the distance sum for node {@code i}
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        if (n <= 1) {
            // A single node (or an empty tree) has nothing to sum; edges are not read.
            return new int[n];
        }

        // 1. Adjacency lists packed into one array: the neighbours of u are
        //    adj[start[u] .. start[u + 1] - 1]. This avoids boxing in a Map<Integer, List>.
        int[] start = new int[n + 1];
        for (int[] edge : edges) {
            start[edge[0] + 1]++;
            start[edge[1] + 1]++;
        }
        for (int u = 0; u < n; u++) {
            start[u + 1] += start[u];
        }
        int[] adj = new int[start[n]];
        int[] next = Arrays.copyOf(start, n);
        for (int[] edge : edges) {
            adj[next[edge[0]]++] = edge[1];
            adj[next[edge[1]]++] = edge[0];
        }

        // 2. Iterative DFS from root 0, recording a pre-order and each node's parent.
        //    In a tree every node is pushed exactly once, so a stack of size n suffices.
        int[] parent = new int[n];
        int[] order = new int[n];
        int[] stack = new int[n];
        int top = 0;
        int reached = 0;
        parent[0] = -1;
        stack[top++] = 0;
        while (top > 0) {
            int u = stack[--top];
            order[reached++] = u;
            for (int k = start[u]; k < start[u + 1]; k++) {
                int v = adj[k];
                if (v != parent[u]) {
                    parent[v] = u;
                    stack[top++] = v;
                }
            }
        }

        // 3. Reverse pre-order visits every child before its parent. Accumulate subtree sizes
        //    and down[u] = sum of distances from u to the nodes of its own subtree.
        int[] size = new int[n];
        int[] down = new int[n];
        for (int i = reached - 1; i >= 0; i--) {
            int u = order[i];
            size[u]++; // count u itself; its children have already been added in
            int p = parent[u];
            if (p >= 0) {
                size[p] += size[u];
                // Every node of u's subtree is one edge farther from p than from u.
                down[p] += down[u] + size[u];
            }
        }

        // 4. Re-root top-down. The root's answer is its full-tree distance sum; each child
        //    adjusts its parent's answer by the edge-crossing rule described above.
        int[] answer = new int[n];
        answer[0] = down[0];
        for (int i = 1; i < reached; i++) {
            int u = order[i];
            answer[u] = answer[parent[u]] + n - 2 * size[u];
        }
        return answer;
    }
}
